{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import mindspore\n",
    "from mindspore.dataset import text, GeneratorDataset, transforms\n",
    "from mindspore import nn\n",
    "\n",
    "from mindnlp import load_dataset\n",
    "\n",
    "from mindnlp.engine import Trainer, Evaluator\n",
    "from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback\n",
    "from mindnlp.metrics import Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the 'train' and 'test' splits of the IMDB dataset and bind each to its own name.\n",
    "imdb_ds = load_dataset('imdb', split=['train', 'test'])\n",
    "imdb_train, imdb_test = imdb_ds['train'], imdb_ds['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display the size reported by the training split.\n",
    "imdb_train.get_dataset_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=16, shuffle=False):\n",
    "    \"\"\"Tokenize the 'text' column, cast 'label' to int32 'labels', and batch the dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset object providing shuffle/map/batch/padded_batch.\n",
    "        tokenizer: callable returning 'input_ids' and 'attention_mask'; its\n",
    "            pad_token_id is used to pad 'input_ids' when not running on Ascend.\n",
    "        max_seq_len: maximum token length passed to the tokenizer (truncation is enabled).\n",
    "        batch_size: rows per batch; also passed to dataset.shuffle() when shuffle is True.\n",
    "        shuffle: if True, shuffle the dataset before tokenizing.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized and batched dataset.\n",
    "    \"\"\"\n",
    "    on_ascend = mindspore.get_context('device_target') == 'Ascend'\n",
    "\n",
    "    # On Ascend every sample is padded to max_seq_len at tokenization time;\n",
    "    # on other targets padding is deferred to padded_batch() below.\n",
    "    encode_kwargs = {'truncation': True, 'max_length': max_seq_len}\n",
    "    if on_ascend:\n",
    "        encode_kwargs['padding'] = 'max_length'\n",
    "\n",
    "    def tokenize(sample):\n",
    "        encoded = tokenizer(sample, **encode_kwargs)\n",
    "        return encoded['input_ids'], encoded['attention_mask']\n",
    "\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(batch_size)\n",
    "\n",
    "    # Tokenize the \"text\" column into 'input_ids' and 'attention_mask'.\n",
    "    dataset = dataset.map(operations=[tokenize], input_columns=\"text\",\n",
    "                          output_columns=['input_ids', 'attention_mask'])\n",
    "    # Cast \"label\" to int32 and expose it under the \"labels\" column name.\n",
    "    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32),\n",
    "                          input_columns=\"label\", output_columns=\"labels\")\n",
    "\n",
    "    if on_ascend:\n",
    "        # Samples already share one length, so plain batching is enough.\n",
    "        return dataset.batch(batch_size)\n",
    "\n",
    "    # Pad 'input_ids' with the tokenizer's pad id and 'attention_mask' with 0.\n",
    "    pad_info = {'input_ids': (None, tokenizer.pad_token_id),\n",
    "                'attention_mask': (None, 0)}\n",
    "    return dataset.padded_batch(batch_size, pad_info=pad_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.transformers import GPTTokenizer\n",
    "\n",
    "# Load the pretrained 'openai-gpt' tokenizer.\n",
    "gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')\n",
    "\n",
    "# Register the special tokens <bos>, <eos> and <pad> with the tokenizer.\n",
    "special_tokens_dict = dict(\n",
    "    bos_token=\"<bos>\",\n",
    "    eos_token=\"<eos>\",\n",
    "    pad_token=\"<pad>\",\n",
    ")\n",
    "num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the original training data 70/30 into train and validation datasets.\n",
    "# Splitting from imdb_ds['train'] rather than the rebound imdb_train keeps this cell\n",
    "# idempotent: re-running it does not split an already-reduced training set again.\n",
    "imdb_train, imdb_val = imdb_ds['train'].split([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tokenized, batched pipelines; only the training data is shuffled.\n",
    "dataset_train = process_dataset(dataset=imdb_train, tokenizer=gpt_tokenizer, shuffle=True)\n",
    "dataset_val = process_dataset(dataset=imdb_val, tokenizer=gpt_tokenizer)\n",
    "dataset_test = process_dataset(dataset=imdb_test, tokenizer=gpt_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the first item from the training pipeline to inspect it.\n",
    "next(dataset_train.create_tuple_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.transformers import GPTForSequenceClassification\n",
    "from mindspore.experimental.optim import Adam\n",
    "\n",
    "# Load the pretrained GPT classifier and define the parameters for training.\n",
    "model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)\n",
    "model.config.pad_token_id = gpt_tokenizer.pad_token_id\n",
    "# Grow the embedding table by 3 to cover the <bos>, <eos> and <pad> tokens\n",
    "# that were added to the tokenizer.\n",
    "model.resize_token_embeddings(model.config.vocab_size + 3)\n",
    "\n",
    "# NOTE: nn.Adam is the optimizer actually used; the experimental Adam import above is unused here.\n",
    "optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)\n",
    "\n",
    "metric = Accuracy()\n",
    "\n",
    "# define callbacks to save checkpoints\n",
    "ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)\n",
    "best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)\n",
    "\n",
    "# Evaluate on the held-out validation split (dataset_val), not on the training data,\n",
    "# so the evaluation metric (and anything driven by it) reflects samples the model was not trained on.\n",
    "trainer = Trainer(network=model, train_dataset=dataset_train,\n",
    "                  eval_dataset=dataset_val, metrics=metric,\n",
    "                  epochs=3, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],\n",
    "                  jit=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training, passing the 'labels' column as the target column.\n",
    "trainer.run(tgt_columns=\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Score the model on the test pipeline with the same metric object used by the trainer.\n",
    "evaluator = Evaluator(\n",
    "    network=model,\n",
    "    eval_dataset=dataset_test,\n",
    "    metrics=metric,\n",
    ")\n",
    "evaluator.run(tgt_columns=\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
